#!/usr/bin/env python
# coding: utf-8
"""
Computes images from approximation modules for immuno.

Input parameters :
- idataset = 1 or 2 : the dataset index
- res: int : the grid resolution (on which we compute the alpha complex) of the module
- num : int : the number of modules approximating the final one
"""
# # Prerequisites



# In[13]:


# Progress message, printed and flushed before the imports below are executed.
print("Loading dependencies", flush=True)
import numpy as np
import matplotlib.pyplot as plt
from mma import splx2bf, approx, from_dump
from classif_helper import *
# Standard-library imports stay after the star import above so that the
# wildcard cannot shadow them.
import os
import pickle
from sys import argv

# Command-line arguments, in order: idataset, res, num (see the module docstring).
if len(argv) < 4:
	# Fail with a usage message and a non-zero status instead of an IndexError traceback.
	raise SystemExit(f"usage: {argv[0]} idataset res num")
idataset = int(argv[1])
if idataset not in (1, 2):
	# SystemExit with a message writes it to stderr and exits with status 1,
	# so a calling shell script or job scheduler can detect the failure
	# (a bare exit() reported success with status 0).
	raise SystemExit("bad argument")
res = int(argv[2])
num = int(argv[3])
print("Arguments :", *argv)


# Array saved next to the modules: its length gives the number of approximation
# modules loaded below, and its values are used as the x-axis of the final plots.
iterator_path = f"modules/immuno/iterator{idataset}_res{res}_num{num}.np"
with open(iterator_path, "rb") as iterator_file:
	iterator = np.load(iterator_file)


# In[21]:
# Settings forwarded as keyword arguments to the modules' image() method.
# The "ps" entry is also iterated over: one image is computed per value of p.
params = {
	"bandwidth": 2_000,
	"dimension": 1,
	"resolution": [200, 200],
	"normalize": 1,
	"ps": [0, 1, 2, np.inf],
	"cb": 1,
}
print(params)


# Distances between an image x and a reference image y.
def _mean_squared_diff(x, y):
	"""Mean of the squared element-wise differences between x and y."""
	return np.square(x - y).mean()

def _scaled_mean_squared_diff(x, y):
	"""Mean squared difference divided by the maximum of the reference y."""
	return _mean_squared_diff(x, y) / y.max()

def _sup_diff(x, y):
	"""Largest absolute element-wise difference between x and y."""
	return np.abs(x - y).max()

def _scaled_sup_diff(x, y):
	"""Largest absolute difference divided by the maximum of the reference y."""
	return _sup_diff(x, y) / y.max()

# Order matters: distances[k] is labelled by distances_names[k].
distances = [
	_mean_squared_diff,
	_scaled_mean_squared_diff,
	_sup_diff,
	_scaled_sup_diff,
]
# NOTE: the entries labelled "L2 norm" are means of squared differences
# (no square root is taken); the labels are also used in output file names.
distances_names = [
	"L2 norm",
	"scaled L2 norm",
	"sup norm",
	"scaled sup norm",
]

print("Computing images...", flush=True)

print("- last imgs...", flush=True)
# Reference images: the module with the highest index (len(iterator)-1) is the
# one every approximation is compared against.
# NOTE(review): pickle can execute arbitrary code on load; these files are
# assumed to be produced locally by the companion pipeline - confirm.
last_module_path = f"modules/immuno/module{idataset}_res{res}_num{num}_{len(iterator)-1}.pkl"
with open(last_module_path, "rb") as module_file:
	# Context manager so the file handle is closed as soon as the dump is read.
	last_mod = from_dump(pickle.load(module_file))
last_imgs = []
for p in params["ps"]:
	plt.figure()
	# params is forwarded verbatim as keyword arguments (including its "ps" entry).
	last = last_mod.image(p=p,plot=True,**params)
	last_imgs.append(last)
	plt.savefig(f"test_immuno{idataset}_bdw{params['bandwidth']}_res{res}_p{p}.png", dpi=200)
	# Close (rather than merely clear) the figure so figures do not pile up in memory.
	plt.close()
# The module itself is no longer needed once its images are stored.
del last_mod

print("- approximation images...", flush=True)
# errors[i, j, k] = distances[k] between the image of module j and the
# reference image, both computed with p = params["ps"][i].
errors = np.zeros(shape=(len(params["ps"]), len(iterator), len(distances)))
# NOTE(review): tqdm is not imported explicitly in this file; presumably it is
# provided by the star import of classif_helper - confirm.
for j,_ in tqdm(enumerate(iterator), total=len(iterator)):
	# Open each dump in a context manager so one file handle is not left
	# unclosed per iteration.
	with open(f"modules/immuno/module{idataset}_res{res}_num{num}_{j}.pkl", "rb") as module_file:
		current_mod = from_dump(pickle.load(module_file))
	for i,p in enumerate(params["ps"]):
		current_img = current_mod.image(p=p, plot=False,**params)
		for k,d in enumerate(distances):
			errors[i,j,k] = d(current_img, last_imgs[i])

print("Saving errors...", flush=True)
errors_path = f"errors/immuno/errors{idataset}_bdw{params['bandwidth']}_res{res}_num{num}.pkl"
# Create the output directory when it is missing, so the computed errors are
# not lost to a FileNotFoundError at the very end of the run.
os.makedirs(os.path.dirname(errors_path), exist_ok=True)
with open(errors_path, 'wb') as file:
	pickle.dump(errors, file)

print("Saving plots...", flush=True)
plots_dir = "images/immuno"
# Create the output directory when it is missing instead of failing in savefig.
os.makedirs(plots_dir, exist_ok=True)
# One figure per distance, with one curve per value of p.
for k, distance_name in enumerate(distances_names):
	plt.figure()
	for i,p in enumerate(params["ps"]):
		# The last entry is left out: it compares the reference module with itself.
		plt.plot(iterator[:-1], errors[i,:-1,k], label=f"p={p}")
	plt.xlabel("Number of points")
	plt.ylabel(distance_name)
	plt.legend()
	plt.savefig(f"{plots_dir}/plot{idataset}_bdw{params['bandwidth']}_{distance_name}_cv_res{res}_num{num}.svg")
	# Close (rather than merely clear) the figure so figures do not pile up in memory.
	plt.close()

print("Done !")




